"""
Phase 4: Deep Structural Break Analysis — Sovereign Spreads
=============================================================
Probes the pre/post GFC structural break with:
  (a) Rolling windows (5yr, 7yr, 10yr) to pinpoint the break year
  (b) Quandt-Andrews sup-F test (all possible breakpoints)
  (c) Interaction with crisis dummy (GFC, eurozone crisis)
  (d) Annual coefficient evolution (year-by-year Z₁ interactions)
  (e) Rating agency behavior: does Z predict rating CHANGES differently?
  (f) Spread decomposition: term structure (3m vs 10y) break patterns

Output: output/tables/rolling_break.md, supf_test.md, crisis_interactions.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import stats

# Project root is a hard-coded absolute path (looks like a WSL mount of a
# Windows C: drive — NOTE(review): confirm; edit this when running elsewhere).
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/sovereign_spreads")
ROOT_DIR = PROJECT_DIR.parent
# Make the sibling "multilateral/src" directory importable so the shared
# PanelGLS estimator can be imported from its `model` module.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel (spread_panel.csv) is read from here; markdown tables go to TABLES_DIR.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"


def stars(p):
    """Return the conventional significance marker for p-value *p*.

    '***' when p < 0.01, '**' when p < 0.05, '*' when p < 0.10, otherwise the
    empty string. A NaN p-value yields '' because every comparison with NaN
    is False.
    """
    cutoffs = ((0.01, '***'), (0.05, '**'), (0.10, '*'))
    return next((marker for bound, marker in cutoffs if p < bound), '')


def run_model(df, dep_var, regressors, label, silent=False, min_obs=50):
    """Fit a PanelGLS regression of *dep_var* on *regressors* and summarise it.

    Regressors that are not columns of *df* are dropped (order of the rest is
    preserved), then rows with a missing value in *dep_var* or any remaining
    regressor are removed. Panel identifiers are taken from the ``iso3`` and
    ``year`` columns.

    Args:
        df: Panel DataFrame.
        dep_var: Name of the dependent-variable column.
        regressors: Candidate regressor column names.
        label: Model label used in printed output and stored in the result.
        silent: If True, print nothing.
        min_obs: Minimum number of complete rows required to fit (default 50).

    Returns:
        A dict with 'label', 'dep_var', 'n_obs', 'n_countries', 'r_squared'
        and, for every regressor ``name`` actually used, 'coef_<name>',
        'se_<name>' and 'p_<name>'. Returns None when *dep_var* is not a
        column, no regressor is available, or fewer than *min_obs* complete
        rows remain.
    """
    regressors = [r for r in regressors if r in df.columns]
    if dep_var not in df.columns:
        return None
    # Without any regressor there is nothing to estimate; avoid handing the
    # estimator a zero-column design matrix.
    if not regressors:
        if not silent:
            print(f"  [{label}] No regressors available — skipping")
        return None
    sub = df.dropna(subset=[dep_var] + regressors).copy()
    if len(sub) < min_obs:
        if not silent:
            print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None
    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)
    if not silent:
        print(f"  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, R²={gls.r_squared:.4f}")
        for i, name in enumerate(regressors):
            sig = stars(gls.pvalues[i])
            print(f"    {name:<25} {gls.beta[i]:>10.4f} ({gls.se[i]:.4f}) {sig}")
    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
    }
    for i, name in enumerate(regressors):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
    return results


def _section(title, *, lead=True):
    """Print *title* between two 70-character '=' rules.

    When *lead* is True a newline precedes the first rule, separating the
    section from earlier output.
    """
    print(("\n" if lead else "") + "=" * 70)
    print(title)
    print("=" * 70)


def _write_table(filename, lines):
    """Write markdown *lines* to TABLES_DIR/filename and report the path."""
    # The output directory may not exist on a fresh checkout.
    TABLES_DIR.mkdir(parents=True, exist_ok=True)
    out = TABLES_DIR / filename
    # Tables contain non-ASCII symbols (Z₁, R², em dashes); pin the encoding
    # instead of relying on the platform default.
    out.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {out}")


def _gls_residuals(data, dep_var, regressors):
    """Fit PanelGLS on *data* and return the residual vector y - X @ beta.

    NOTE(review): only the slope vector ``beta`` is subtracted; if PanelGLS
    absorbs country/year effects internally they remain in these residuals.
    Confirm against PanelGLS before interpreting RSS levels.
    """
    gls = PanelGLS()
    gls.fit(data[dep_var].values, data[regressors].values,
            data['iso3'].values, data['year'].values)
    return data[dep_var].values - data[regressors].values @ gls.beta


def _rolling_rating(df, vars_r):
    """Part A: fit the rating model on rolling 5-, 7- and 10-year windows.

    Windows start every year from 1990; the last window of each width ends in
    2024. Returns one dict per window accepted by ``run_model`` with the
    window bounds, the Z_1 coefficient/SE/p-value (NaN when Z_1 was not
    estimated), N and R².
    """
    _section("PART A: ROLLING WINDOWS — Z₁ on rating_numeric")
    results = []
    for width in [5, 7, 10]:
        print(f"\n  --- Window width = {width} years ---")
        for start_year in range(1990, 2025 - width + 1):
            end_year = start_year + width - 1
            window = df[(df['year'] >= start_year) & (df['year'] <= end_year)]
            r = run_model(window, 'rating_numeric', vars_r,
                          f"w{width}: {start_year}-{end_year}", silent=True)
            if r:
                results.append({
                    'width': width, 'start': start_year, 'end': end_year,
                    'mid': (start_year + end_year) / 2,
                    'Z_1_coef': r.get('coef_Z_1', np.nan),
                    'Z_1_se': r.get('se_Z_1', np.nan),
                    'Z_1_p': r.get('p_Z_1', np.nan),
                    'n_obs': r['n_obs'],
                    'r_squared': r['r_squared'],
                })
    if results:
        _report_rolling_rating(pd.DataFrame(results))
    return results


def _report_rolling_rating(rdf):
    """Print Part A results and the 5%-significance switches among 10-year windows."""
    print(f"\n  Rolling results ({len(rdf)} windows):")
    print(f"  {'Width':>5} {'Window':>12} {'Z₁ coef':>10} {'SE':>8} {'p':>8} {'Sig':>4} {'N':>6}")
    for _, row in rdf.iterrows():
        sig = stars(row['Z_1_p'])
        print(f"  {int(row['width']):>5} {int(row['start'])}-{int(row['end']):>4} "
              f"{row['Z_1_coef']:>10.2f} {row['Z_1_se']:>8.2f} {row['Z_1_p']:>8.4f} {sig:>4} {int(row['n_obs']):>6}")

    w10 = rdf[rdf['width'] == 10].copy()
    if len(w10) <= 3:
        return
    w10['sig'] = w10['Z_1_p'] < 0.05
    # A transition is a window whose significance differs from the previous
    # window's. The first window has no predecessor: shift() yields NaN there,
    # which always compares unequal, so it must be excluded explicitly.
    prev = w10['sig'].shift(1)
    transitions = w10[prev.notna() & (w10['sig'] != prev)]
    if len(transitions) > 0:
        print("\n  ★ Significance transitions (10yr windows):")
        for _, t in transitions.iterrows():
            direction = "→ significant" if t['sig'] else "→ insignificant"
            print(f"    {int(t['start'])}-{int(t['end'])}: Z₁={t['Z_1_coef']:.2f} (p={t['Z_1_p']:.3f}) {direction}")


def _rolling_spread(df, vars_s):
    """Part B: fit the spread model on rolling 7- and 10-year windows and print Z_1."""
    _section("PART B: ROLLING WINDOWS — Z₁ on sovereign_spread")
    results = []
    for width in [7, 10]:
        for start_year in range(1990, 2025 - width + 1):
            end_year = start_year + width - 1
            window = df[(df['year'] >= start_year) & (df['year'] <= end_year)]
            r = run_model(window, 'sovereign_spread', vars_s,
                          f"sw{width}: {start_year}-{end_year}", silent=True)
            if r:
                results.append({
                    'width': width, 'start': start_year, 'end': end_year,
                    'Z_1_coef': r.get('coef_Z_1', np.nan),
                    'Z_1_p': r.get('p_Z_1', np.nan),
                    'n_obs': r['n_obs'],
                })
    if results:
        print("\n  Spread rolling results:")
        for row in results:
            sig = stars(row['Z_1_p'])
            print(f"  w{int(row['width'])}: {int(row['start'])}-{int(row['end'])} "
                  f"Z₁={row['Z_1_coef']:>8.2f} p={row['Z_1_p']:.3f} {sig} N={int(row['n_obs'])}")
    return results


def _sup_f_test(df, vars_chow):
    """Part C: search for a single break in the rating model (sup-F).

    On the 1995-2020 complete-case sample, computes a Chow F-statistic for a
    break after each year 2000-2015 (skipping years with fewer than 50 obs on
    either side), prints each, and writes ``supf_test.md``.

    The p-values are pointwise F(k, n - 2k) tail probabilities. They ignore
    the search over candidate break years, so they overstate the significance
    of the sup-F statistic; Andrews (1993) critical values are needed for a
    formal test.

    Returns:
        (f_stats, best): the per-break dicts and the pandas row with the
        largest F, or (f_stats, None) when no break year could be tested.
    """
    _section("PART C: SUP-F TEST — All possible breakpoints")
    f_stats = []
    if 'rating_numeric' not in df.columns or not vars_chow:
        print("  rating_numeric or regressors unavailable — skipping")
        return f_stats, None

    full = df.dropna(subset=['rating_numeric'] + vars_chow).copy()
    full = full[(full['year'] >= 1995) & (full['year'] <= 2020)]
    k = len(vars_chow)
    n = len(full)
    # The pooled (no-break) fit does not depend on the break year, so fit it
    # once — lazily, so nothing is fitted when no break year qualifies.
    rss_full = None

    for break_year in range(2000, 2016):
        pre = full[full['year'] <= break_year]
        post = full[full['year'] > break_year]
        if len(pre) < 50 or len(post) < 50:
            continue

        if rss_full is None:
            rss_full = np.sum(_gls_residuals(full, 'rating_numeric', vars_chow) ** 2)
        rss_pre = np.sum(_gls_residuals(pre, 'rating_numeric', vars_chow) ** 2)
        rss_post = np.sum(_gls_residuals(post, 'rating_numeric', vars_chow) ** 2)

        denom = (rss_pre + rss_post) / (n - 2 * k)
        if denom > 0:
            F = ((rss_full - rss_pre - rss_post) / k) / denom
            # This RSS is not what the GLS fits minimise, so F can come out
            # negative; clamp before taking the upper-tail probability.
            p = stats.f.sf(max(F, 0), k, n - 2 * k)
            f_stats.append({'break_year': break_year, 'F': F, 'p': p,
                            'n_pre': len(pre), 'n_post': len(post)})
            print(f"  Break at {break_year}: F={F:.2f}, p={p:.4f} {stars(p)}")

    if not f_stats:
        return f_stats, None

    fdf = pd.DataFrame(f_stats)
    best = fdf.loc[fdf['F'].idxmax()]
    print(f"\n  ★ Sup-F breakpoint: {int(best['break_year'])} (F={best['F']:.2f}, p={best['p']:.4f})")

    md = ["# Sup-F (Quandt-Andrews) Break Test — Ratings\n"]
    md.append("| Break year | F | p (pointwise) | Sig | N pre | N post |")
    md.append("|---|---|---|---|---|---|")
    for row in f_stats:
        md.append(f"| {row['break_year']} | {row['F']:.2f} | {row['p']:.4f} "
                  f"| {stars(row['p'])} | {row['n_pre']:,} | {row['n_post']:,} |")
    md.append(f"\n*Sup-F breakpoint: {int(best['break_year'])} "
              f"(F={best['F']:.2f}, p={best['p']:.4f})*")
    md.append("\n*p-values are pointwise F-test probabilities and do not account "
              "for the search over break years; compare the sup-F statistic with "
              "Andrews (1993) critical values.*")
    _write_table("supf_test.md", md)
    return f_stats, best


def _crisis_interactions(df, demo_vars, rating_controls, spread_controls):
    """Part D: crisis-regime dummies and Z × regime interactions.

    Adds dummy and interaction columns to *df* in place, fits D1 (Z×post-GFC
    on ratings), D2 (Z×post-EZ on ratings) and D3 (Z×post-GFC on spreads),
    prints them and writes key coefficients to ``crisis_interactions.md``.
    *demo_vars* must list only columns present in *df*.
    """
    _section("PART D: CRISIS DUMMIES AND INTERACTIONS")

    # GFC dummy (2008-2009)
    df['gfc'] = ((df['year'] >= 2008) & (df['year'] <= 2009)).astype(int)
    # Eurozone crisis (2010-2012)
    df['ez_crisis'] = ((df['year'] >= 2010) & (df['year'] <= 2012)).astype(int)
    # NOTE: gfc / ez_crisis are constructed but not entered in any model below.
    # Post-GFC regime
    df['post_gfc'] = (df['year'] >= 2008).astype(int)
    # Post-eurozone-crisis regime
    df['post_ez'] = (df['year'] >= 2012).astype(int)

    for zv in demo_vars:
        df[f'{zv}_x_post_gfc'] = df[zv] * df['post_gfc']
        df[f'{zv}_x_post_ez'] = df[zv] * df['post_ez']
    post_int = [f'{zv}_x_post_gfc' for zv in demo_vars]
    ez_int = [f'{zv}_x_post_ez' for zv in demo_vars]

    specs = [
        ('rating_numeric', demo_vars + rating_controls + ['post_gfc'] + post_int,
         "D1: Z×post_GFC → rating"),
        ('rating_numeric', demo_vars + rating_controls + ['post_ez'] + ez_int,
         "D2: Z×post_EZ_crisis → rating"),
        ('sovereign_spread', demo_vars + spread_controls + ['post_gfc'] + post_int,
         "D3: Z×post_GFC → spread"),
    ]
    crisis_results = []
    for dep_var, regressors, label in specs:
        r = run_model(df, dep_var, regressors, label)
        if r:
            crisis_results.append(r)

    if not crisis_results:
        return crisis_results

    key_crisis_vars = demo_vars + ['post_gfc', 'post_ez'] + post_int + ez_int
    md = ["# Crisis Interaction Models\n"]
    md.append("| Model | Dep Var | N | Countries | R² |")
    md.append("|---|---|---|---|---|")
    for r in crisis_results:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
                  f"| {r['n_countries']} | {r['r_squared']:.3f} |")
    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in crisis_results:
        for var in key_crisis_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")
    _write_table("crisis_interactions.md", md)
    return crisis_results


def _three_period_split(df, vars_r, vars_s):
    """Part E: fit rating and spread models separately in three sub-periods."""
    _section("PART E: THREE-PERIOD SPLIT")
    # Pre-GFC (1990-2007), GFC+EZ crisis (2008-2012), Post-crisis (2013-2024)
    periods = [
        ("1990-2007", 1990, 2007),
        ("2008-2012", 2008, 2012),
        ("2013-2024", 2013, 2024),
    ]
    for label, y1, y2 in periods:
        sub = df[(df['year'] >= y1) & (df['year'] <= y2)]
        print(f"\n  --- {label} ---")
        run_model(sub, 'rating_numeric', vars_r, f"Rating {label}")
        run_model(sub, 'sovereign_spread', vars_s, f"Spread {label}")


def _rating_spread_convergence(df, demo_vars, controls_spread, vars_r, vars_s):
    """Part F: does each outcome's unexplained part predict the other?

    Computes rating residuals (F1: do they predict spreads?) and spread
    residuals (F2: do they predict ratings?) on one common complete-case
    sample.
    """
    _section("PART F: DO RATINGS AND SPREADS TELL THE SAME STORY?")
    if not {'rating_numeric', 'sovereign_spread'} <= set(df.columns):
        print("  rating_numeric / sovereign_spread unavailable — skipping")
        return

    # Both first stages run on `sub`, so it must be complete for the
    # regressors of BOTH models; otherwise the spread fit sees NaNs.
    needed = list(dict.fromkeys(['rating_numeric', 'sovereign_spread'] + vars_r + vars_s))
    sub = df.dropna(subset=needed).copy()
    if len(sub) < 50:
        print(f"  Insufficient obs ({len(sub)}) — skipping")
        return

    # First stage: Z → rating (residual = unexplained rating)
    sub['rating_resid'] = _gls_residuals(sub, 'rating_numeric', vars_r)
    # Does the rating residual predict spreads?
    run_model(sub, 'sovereign_spread',
              demo_vars + controls_spread + ['rating_resid'],
              "F1: Z → spread | rating_resid")

    # Opposite: spread residual predicting rating
    sub['spread_resid'] = _gls_residuals(sub, 'sovereign_spread', vars_s)
    run_model(sub, 'rating_numeric',
              vars_r + ['spread_resid'],
              "F2: Z → rating | spread_resid")


def _write_rolling_table(rolling_results, best):
    """Write Part A's rolling-window results (and the sup-F break, if any) to rolling_break.md."""
    md = ["# Rolling Window Analysis — Z₁ on Ratings\n"]
    md.append("| Width | Window | Z₁ Coef | SE | p-value | Sig | N | R² |")
    md.append("|---|---|---|---|---|---|---|---|")
    for row in rolling_results:
        sig = stars(row['Z_1_p'])
        md.append(f"| {int(row['width'])} | {int(row['start'])}-{int(row['end'])} "
                  f"| {row['Z_1_coef']:.2f} | {row['Z_1_se']:.2f} "
                  f"| {row['Z_1_p']:.4f} | {sig} | {int(row['n_obs'])} "
                  f"| {row['r_squared']:.3f} |")
    if best is not None:
        md.append(f"\n*Sup-F breakpoint: {int(best['break_year'])} "
                  f"(F={best['F']:.2f}, p={best['p']:.4f})*")
    _write_table("rolling_break.md", md)


def main():
    """Run all Phase 4 structural-break diagnostics.

    Reads ``spread_panel.csv`` from PROCESSED_DIR, prints the results of
    Parts A-F, and writes ``rolling_break.md``, ``supf_test.md`` and
    ``crisis_interactions.md`` to TABLES_DIR (created if missing). Crisis
    dummy and interaction columns are added to the in-memory panel.
    """
    _section("PHASE 4: Deep Structural Break Analysis", lead=False)

    df = pd.read_csv(PROCESSED_DIR / "spread_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls_rating = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'govt_debt_gdp']
    controls_spread = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']

    columns = set(df.columns)
    demo_present = [v for v in demo_vars if v in columns]
    rating_controls_present = [v for v in controls_rating if v in columns]
    # Filter by name rather than slicing by position, so a missing Z variable
    # cannot shift a demographic variable into the "controls" list.
    vars_r = demo_present + rating_controls_present
    vars_s = [v for v in demo_vars + controls_spread if v in columns]

    rolling_results = _rolling_rating(df, vars_r)
    _rolling_spread(df, vars_s)
    _, best = _sup_f_test(df, vars_r)
    _crisis_interactions(df, demo_present, rating_controls_present, controls_spread)
    _three_period_split(df, vars_r, vars_s)
    _rating_spread_convergence(df, demo_vars, controls_spread, vars_r, vars_s)

    if rolling_results:
        _write_rolling_table(rolling_results, best)

    _section("Phase 4 complete.")


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
